{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Alpha_code_example.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0rklDqluPYZB"
      },
      "source": [
        "## Example of training using the Alpha Neuron"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ptd12vmR7H8D",
        "outputId": "79499811-cc47-4470-f190-f3a8581df051"
      },
      "source": [
        "# Install snnTorch with %pip so the package lands in the environment of the running kernel\n",
        "# (-q keeps the install log short).\n",
        "# Alternative: work from the development branch by cloning it manually:\n",
        "#   git clone -b alpha_neuron --single-branch https://github.com/jeshraghian/snntorch.git\n",
        "# NOTE(review): the cells below call init_alpha(batch, features) and unpack four values;\n",
        "# confirm the installed release still exposes that API, and pin a version if it does not.\n",
        "%pip install -q snntorch"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "fatal: destination path 'snntorch' already exists and is not an empty directory.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5VNVo64L93Zo",
        "outputId": "69255447-8261-4119-f851-066c05ea510c"
      },
      "source": [
        "# %cd snntorch\n",
        "import snntorch as snn"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/snntorch/snntorch\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CEUipC6UPebM"
      },
      "source": [
        "## Import Packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AQeg7QZ69fSt"
      },
      "source": [
        "import snntorch as snn\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "import numpy as np\n",
        "import itertools\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0OEdE6bEPf20"
      },
      "source": [
        "## Define Network and Parameters"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MHmCvayV7LJR"
      },
      "source": [
        "# Sanity check for the Alpha neuron: can a small network built from it learn MNIST?\n",
        "\n",
        "# Layer sizes\n",
        "num_inputs = 28 * 28  # one input per pixel of a flattened 28x28 image\n",
        "num_hidden = 1000\n",
        "num_outputs = 10  # one output per digit class\n",
        "\n",
        "# Training parameters\n",
        "batch_size = 128\n",
        "data_path = \"/data/mnist\"\n",
        "\n",
        "# Temporal dynamics\n",
        "num_steps = 25  # number of simulation steps per forward pass\n",
        "alpha = 0.9  # forwarded to snn.Alpha as alpha\n",
        "beta = 0.8  # forwarded to snn.Alpha as beta\n",
        "\n",
        "# Tensor dtype and compute device (CUDA when available, otherwise CPU)\n",
        "dtype = torch.float\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Define Network\n",
        "class Net(nn.Module):\n",
        "    \"\"\"Two-layer fully connected spiking network built from Alpha neurons.\n",
        "\n",
        "    Each layer is an nn.Linear projection followed by an snn.Alpha neuron.\n",
        "    The same input is presented at every one of the num_steps simulation steps.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "        # initialize layers\n",
        "        self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
        "        self.lif1 = snn.Alpha(alpha=alpha, beta=beta)\n",
        "        self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
        "        self.lif2 = snn.Alpha(alpha=alpha, beta=beta)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Simulate the network for num_steps steps on one batch.\n",
        "\n",
        "        Args:\n",
        "            x: Batch of flattened inputs, expected shape (batch, num_inputs).\n",
        "\n",
        "        Returns:\n",
        "            Tuple (spk2_rec, mem2_rec): the output layer's spikes and membrane\n",
        "            potentials, each stacked along a new leading time dimension.\n",
        "        \"\"\"\n",
        "        # Size the neuron state from the input itself and from the layers' own\n",
        "        # widths rather than from notebook globals, so the state always matches\n",
        "        # the batch that is actually being processed.\n",
        "        num_samples = x.size(0)\n",
        "        spk1, syn_exc1, syn_inh1, mem1 = self.lif1.init_alpha(num_samples, self.fc1.out_features)\n",
        "        spk2, syn_exc2, syn_inh2, mem2 = self.lif2.init_alpha(num_samples, self.fc2.out_features)\n",
        "\n",
        "        # Record the final layer\n",
        "        spk2_rec = []\n",
        "        mem2_rec = []\n",
        "\n",
        "        for step in range(num_steps):\n",
        "            # Hidden layer: project the input, then update the Alpha neuron state\n",
        "            cur1 = self.fc1(x)\n",
        "            spk1, syn_exc1, syn_inh1, mem1 = self.lif1(cur1, syn_exc1, syn_inh1, mem1)\n",
        "\n",
        "            # Output layer: driven by the hidden layer's spikes\n",
        "            cur2 = self.fc2(spk1)\n",
        "            spk2, syn_exc2, syn_inh2, mem2 = self.lif2(cur2, syn_exc2, syn_inh2, mem2)\n",
        "\n",
        "            spk2_rec.append(spk2)\n",
        "            mem2_rec.append(mem2)\n",
        "\n",
        "        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n",
        "  \n",
        "net = Net().to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2N00-2eDPl2G"
      },
      "source": [
        "## DataLoaders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NblLifHj9qO-",
        "outputId": "e890995e-a9fd-490d-f278-caf6bce9c21e"
      },
      "source": [
        "# Preprocessing applied to every image: resize to 28x28, convert to grayscale,\n",
        "# turn the image into a tensor, then normalize with mean 0 and std 1.\n",
        "preprocessing_steps = [\n",
        "    transforms.Resize((28, 28)),\n",
        "    transforms.Grayscale(),\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize((0,), (1,)),\n",
        "]\n",
        "transform = transforms.Compose(preprocessing_steps)\n",
        "\n",
        "# Train and test splits share the same transform and storage location.\n",
        "mnist_train, mnist_test = (\n",
        "    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)\n",
        "    for is_train in (True, False)\n",
        ")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
            "  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kx8nATF69tEk"
      },
      "source": [
        "# Create DataLoaders: both shuffle and drop the last incomplete batch,\n",
        "# so every batch they yield holds exactly batch_size samples.\n",
        "loader_options = dict(batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "train_loader = DataLoader(mnist_train, **loader_options)\n",
        "test_loader = DataLoader(mnist_test, **loader_options)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gqDmeDJcP_HL"
      },
      "source": [
        "## Print Accuracy Function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EJYEQpHk-MzX"
      },
      "source": [
        "def print_batch_accuracy(data, targets, train=False):\n",
        "    \"\"\"Print the accuracy of the notebook-level net on a single batch.\n",
        "\n",
        "    The predicted class of each sample is the output neuron with the largest\n",
        "    spike count summed over all time steps.\n",
        "\n",
        "    Args:\n",
        "        data: Batch of inputs; flattened to (batch, -1) before the forward pass.\n",
        "        targets: Class labels for the samples in data.\n",
        "        train: When True the line is labelled as train-set accuracy,\n",
        "            otherwise as test-set accuracy.\n",
        "    \"\"\"\n",
        "    with torch.no_grad():\n",
        "        # Flatten using the batch's own size rather than the global batch_size,\n",
        "        # so the reshape can never mix pixels from different samples.\n",
        "        output, _ = net(data.view(data.size(0), -1))\n",
        "        _, idx = output.sum(dim=0).max(1)\n",
        "        acc = np.mean((targets == idx).detach().cpu().numpy())\n",
        "\n",
        "    label = \"Train\" if train else \"Test\"\n",
        "    print(f\"{label} Set Accuracy: {acc}\")\n",
        "\n",
        "def train_printer():\n",
        "    \"\"\"Print a progress report for the current point in training.\n",
        "\n",
        "    Reads notebook-level state set by the training loop (epoch,\n",
        "    minibatch_counter, counter, loss_hist, test_loss_hist and the current\n",
        "    train/test batches) and prints the losses at index counter followed by\n",
        "    the accuracy on both batches.\n",
        "    \"\"\"\n",
        "    report_lines = (\n",
        "        f\"Epoch {epoch}, Minibatch {minibatch_counter}\",\n",
        "        f\"Train Set Loss: {loss_hist[counter]}\",\n",
        "        f\"Test Set Loss: {test_loss_hist[counter]}\",\n",
        "    )\n",
        "    print(*report_lines, sep=\"\\n\")\n",
        "\n",
        "    # Train batch first, then test batch\n",
        "    for batch, labels, is_train in (\n",
        "        (data_it, targets_it, True),\n",
        "        (testdata_it, testtargets_it, False),\n",
        "    ):\n",
        "        print_batch_accuracy(batch, labels, train=is_train)\n",
        "    print(\"\\n\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m2Zz5xesQBBk"
      },
      "source": [
        "## Define Loss & Optimizer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4Uwvn33m-OuY"
      },
      "source": [
        "# Adam optimizer over all parameters of the network\n",
        "optimizer = torch.optim.Adam(\n",
        "    net.parameters(),\n",
        "    lr=2e-4,\n",
        "    betas=(0.9, 0.999),\n",
        ")\n",
        "\n",
        "# Log-softmax over the last dimension, paired with the negative log-likelihood loss\n",
        "log_softmax_fn = nn.LogSoftmax(dim=-1)\n",
        "loss_fn = nn.NLLLoss()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wY-94aYrQCXE"
      },
      "source": [
        "## Training Loop"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "hQvfXduS-QmB",
        "outputId": "7ec1f1dc-ce3f-4834-f601-9b5c34245225"
      },
      "source": [
        "num_epochs = 10\n",
        "\n",
        "loss_hist = []\n",
        "test_loss_hist = []\n",
        "counter = 0  # minibatch index across all epochs; indexes the two loss histories\n",
        "\n",
        "# Outer training loop\n",
        "for epoch in range(num_epochs):\n",
        "    minibatch_counter = 0\n",
        "\n",
        "    # Minibatch training loop\n",
        "    for data_it, targets_it in train_loader:\n",
        "        data_it = data_it.to(device)\n",
        "        targets_it = targets_it.to(device)\n",
        "\n",
        "        spk_rec, mem_rec = net(data_it.view(data_it.size(0), -1))\n",
        "        log_p_y = log_softmax_fn(mem_rec)\n",
        "        loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
        "\n",
        "        # Sum loss over time steps: BPTT\n",
        "        for step in range(num_steps):\n",
        "            loss_val += loss_fn(log_p_y[step], targets_it)\n",
        "\n",
        "        # Gradient calculation\n",
        "        optimizer.zero_grad()\n",
        "        loss_val.backward()\n",
        "\n",
        "        # Weight update\n",
        "        optimizer.step()\n",
        "\n",
        "        # Store loss history for future plotting\n",
        "        loss_hist.append(loss_val.item())\n",
        "\n",
        "        # Test set: draw the first batch of a fresh pass over the shuffled loader.\n",
        "        # A plain iterator is enough; wrapping the loader in itertools.cycle on\n",
        "        # every minibatch only built a cycle object that was discarded after a\n",
        "        # single next().\n",
        "        testdata_it, testtargets_it = next(iter(test_loader))\n",
        "        testdata_it = testdata_it.to(device)\n",
        "        testtargets_it = testtargets_it.to(device)\n",
        "\n",
        "        # Test set forward pass (no gradients needed)\n",
        "        with torch.no_grad():\n",
        "            test_spk, test_mem = net(testdata_it.view(testdata_it.size(0), -1))\n",
        "\n",
        "            # Test set loss: log-probabilities are summed over time before the NLL loss\n",
        "            log_p_ytest = log_softmax_fn(test_mem)\n",
        "            log_p_ytest = log_p_ytest.sum(dim=0)\n",
        "            loss_val_test = loss_fn(log_p_ytest, testtargets_it)\n",
        "            test_loss_hist.append(loss_val_test.item())\n",
        "\n",
        "        # Print test/train loss/accuracy every 50 minibatches\n",
        "        if counter % 50 == 0:\n",
        "            train_printer()\n",
        "        minibatch_counter += 1\n",
        "        counter += 1\n",
        "\n",
        "# Additional names for the two histories (same list objects, not copies)\n",
        "loss_hist_true_grad = loss_hist\n",
        "test_loss_hist_true_grad = test_loss_hist"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0, Minibatch 0\n",
            "Train Set Loss: 57.23897171020508\n",
            "Test Set Loss: 56.607032775878906\n",
            "Train Set Accuracy: 0.1171875\n",
            "Test Set Accuracy: 0.0625\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 50\n",
            "Train Set Loss: 39.31348419189453\n",
            "Test Set Loss: 37.7827033996582\n",
            "Train Set Accuracy: 0.3203125\n",
            "Test Set Accuracy: 0.3046875\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 100\n",
            "Train Set Loss: 27.320207595825195\n",
            "Test Set Loss: 26.991588592529297\n",
            "Train Set Accuracy: 0.40625\n",
            "Test Set Accuracy: 0.390625\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 150\n",
            "Train Set Loss: 25.243783950805664\n",
            "Test Set Loss: 24.81044578552246\n",
            "Train Set Accuracy: 0.390625\n",
            "Test Set Accuracy: 0.296875\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 200\n",
            "Train Set Loss: 26.74785614013672\n",
            "Test Set Loss: 24.078622817993164\n",
            "Train Set Accuracy: 0.3359375\n",
            "Test Set Accuracy: 0.390625\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 250\n",
            "Train Set Loss: 21.86463737487793\n",
            "Test Set Loss: 20.568628311157227\n",
            "Train Set Accuracy: 0.3515625\n",
            "Test Set Accuracy: 0.3515625\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 300\n",
            "Train Set Loss: 20.57036590576172\n",
            "Test Set Loss: 20.193082809448242\n",
            "Train Set Accuracy: 0.3515625\n",
            "Test Set Accuracy: 0.3984375\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 350\n",
            "Train Set Loss: 19.613008499145508\n",
            "Test Set Loss: 21.873371124267578\n",
            "Train Set Accuracy: 0.3671875\n",
            "Test Set Accuracy: 0.3828125\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 400\n",
            "Train Set Loss: 18.950220108032227\n",
            "Test Set Loss: 19.205629348754883\n",
            "Train Set Accuracy: 0.3359375\n",
            "Test Set Accuracy: 0.3359375\n",
            "\n",
            "\n",
            "Epoch 0, Minibatch 450\n",
            "Train Set Loss: 19.0748291015625\n",
            "Test Set Loss: 19.952434539794922\n",
            "Train Set Accuracy: 0.3515625\n",
            "Test Set Accuracy: 0.234375\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 32\n",
            "Train Set Loss: 17.7125301361084\n",
            "Test Set Loss: 18.93415641784668\n",
            "Train Set Accuracy: 0.3515625\n",
            "Test Set Accuracy: 0.34375\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 82\n",
            "Train Set Loss: 18.431243896484375\n",
            "Test Set Loss: 18.173742294311523\n",
            "Train Set Accuracy: 0.46875\n",
            "Test Set Accuracy: 0.3203125\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 132\n",
            "Train Set Loss: 20.278743743896484\n",
            "Test Set Loss: 18.996170043945312\n",
            "Train Set Accuracy: 0.328125\n",
            "Test Set Accuracy: 0.34375\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 182\n",
            "Train Set Loss: 17.771526336669922\n",
            "Test Set Loss: 17.510616302490234\n",
            "Train Set Accuracy: 0.2890625\n",
            "Test Set Accuracy: 0.234375\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 232\n",
            "Train Set Loss: 16.881547927856445\n",
            "Test Set Loss: 17.74418067932129\n",
            "Train Set Accuracy: 0.265625\n",
            "Test Set Accuracy: 0.3125\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 282\n",
            "Train Set Loss: 18.29578399658203\n",
            "Test Set Loss: 18.975635528564453\n",
            "Train Set Accuracy: 0.375\n",
            "Test Set Accuracy: 0.265625\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 332\n",
            "Train Set Loss: 19.05307960510254\n",
            "Test Set Loss: 19.7917423248291\n",
            "Train Set Accuracy: 0.3203125\n",
            "Test Set Accuracy: 0.25\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 382\n",
            "Train Set Loss: 16.641101837158203\n",
            "Test Set Loss: 17.99464988708496\n",
            "Train Set Accuracy: 0.265625\n",
            "Test Set Accuracy: 0.28125\n",
            "\n",
            "\n",
            "Epoch 1, Minibatch 432\n",
            "Train Set Loss: 15.80801010131836\n",
            "Test Set Loss: 16.729612350463867\n",
            "Train Set Accuracy: 0.265625\n",
            "Test Set Accuracy: 0.2890625\n",
            "\n",
            "\n",
            "Epoch 2, Minibatch 14\n",
            "Train Set Loss: 19.15610122680664\n",
            "Test Set Loss: 16.361337661743164\n",
            "Train Set Accuracy: 0.3359375\n",
            "Test Set Accuracy: 0.3203125\n",
            "\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-23-47d0b9ea8a73>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     33\u001b[0m         \u001b[0;31m# Test set\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m         \u001b[0mtest_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitertools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcycle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m         \u001b[0mtestdata_it\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtesttargets_it\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m         \u001b[0mtestdata_it\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtestdata_it\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     37\u001b[0m         \u001b[0mtesttargets_it\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtesttargets_it\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    519\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    520\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    522\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    523\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    559\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    560\u001b[0m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 561\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    562\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    563\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     45\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m    132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m             \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m     58\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m             \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     61\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1045\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1046\u001b[0;31m         \u001b[0mforward_call\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_tracing_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1047\u001b[0m         \u001b[0;31m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1048\u001b[0m         \u001b[0;31m# this function, and just call forward.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ]
    }
  ]
}